#!/usr/bin/env python
"""Generate requirements/*.txt files from pyproject.toml."""

import sys
from pathlib import Path

try:  # standard module since Python 3.11
    import tomllib as toml
except ImportError:
    try:  # available for older Python via pip
        import tomli as toml
    except ImportError:
        sys.exit("Please install `tomli` first: `{mamba, pip} install tomli`")

# Resolve to an absolute path: before Python 3.9 a script's ``__file__`` may be
# relative, in which case ``parent.parent`` could collapse to "." and point
# repo_dir at the current working directory instead of the repository root.
script_pth = Path(__file__).resolve()
# The script is expected to live one directory below the repository root.
repo_dir = script_pth.parent.parent
script_relpth = script_pth.relative_to(repo_dir)
# Lines prepended to every generated requirements file.
header = [
    f"# Generated via {script_relpth.as_posix()}.",
    "# Do not edit this file; modify `pyproject.toml` instead "
    f"and run `python {script_relpth.as_posix()}`.",
]


def generate_requirement_file(name, req_list, *, extra_list=None):
    """Write ``requirements/<name>.txt`` and return the requirement lines written.

    Parameters
    ----------
    name : str
        Base name of the output file (``requirements/{name}.txt`` under
        ``repo_dir``). The name ``"build"`` receives special handling.
    req_list : list of str
        Requirement entries, e.g. taken from ``pyproject.toml``. The list
        passed in is not modified.
    extra_list : list of str, optional
        Extra lines appended, unchanged, after the processed ``req_list``.

    Returns
    -------
    list of str
        The lines written to the file after the module-level ``header``.
    """
    req_fname = repo_dir / "requirements" / f"{name}.txt"

    # Temporarily disabled packages: an entry that matches a key exactly is
    # replaced by a comment, so the file still mentions the dependency
    # without installing it.
    disabled = {
        # remove once scikit-umfpack issues are resolved
        "scikit-umfpack": "# scikit-umfpack  # circular dependency issues",
        # remove once gmpy2 supports Python 3.12
        "gmpy2": "# gmpy2  # does not yet support Python 3.12",
    }
    # Builds a new list, so the caller's ``req_list`` is never mutated.
    req_list = [disabled.get(req, req) for req in req_list]

    if name == "build":
        # Drop every entry mentioning numpy from this list and add ninja;
        # numpy lines supplied via ``extra_list`` are appended afterwards
        # and therefore kept.
        req_list = [req for req in req_list if "numpy" not in req]
        req_list.append("ninja")

    if extra_list:
        req_list += extra_list

    # Write UTF-8 explicitly rather than relying on the locale's encoding.
    req_fname.write_text("\n".join(header + req_list) + "\n", encoding="utf-8")
    return req_list


def main():
    """Regenerate the ``requirements/*.txt`` files from ``pyproject.toml``.

    Writes:

    - ``default.txt`` from ``project.dependencies``;
    - ``build.txt`` from ``build-system.requires`` plus the default requirements;
    - one file per ``project.optional-dependencies`` group;
    - ``all.txt``, which pulls in a fixed set of the generated files via ``-r``.
    """
    # TOML documents are UTF-8 by specification; decode explicitly so that
    # non-ASCII content is read correctly regardless of the locale encoding.
    pyproject_text = (repo_dir / "pyproject.toml").read_text(encoding="utf-8")
    pyproject = toml.loads(pyproject_text)

    default = generate_requirement_file("default", pyproject["project"]["dependencies"])
    generate_requirement_file("build", pyproject["build-system"]["requires"],
                              extra_list=default)

    for key, opt_list in pyproject["project"]["optional-dependencies"].items():
        generate_requirement_file(key, opt_list)

    # generate requirements/all.txt; the listed names are expected to match
    # files produced above (build.txt plus optional-dependency groups).
    all_path = repo_dir / "requirements" / "all.txt"
    files = ["build", "dev", "doc", "test"]
    reqs = [f"-r {x}.txt" for x in files]
    all_path.write_text("\n".join(header + reqs) + "\n", encoding="utf-8")


# Regenerate the requirements files only when executed as a script, not on import.
if __name__ == "__main__":
    main()
